import os
import jax
import optax
import jax.numpy as jnp
import jax.random as random
import wandb
import pickle
import data
import model
import loss_functions as loss_f
from sklearn.manifold import Isomap
import evaluation

def train_isometry(args):
    """Train the CNF ``varphi`` so that its output ``z`` preserves Isomap distances.

    Pairwise graph distances ``dij`` between samples come from scikit-learn's
    ``Isomap(...).dist_matrix_``; the Euclidean distances between embedded points
    are pushed towards them by the global and graph-matching losses (after
    ``args.warmup`` epochs), together with low-dimensionality and Jacobian terms.
    Per-epoch training losses and test metrics are logged to Weights & Biases and
    the parameters with the lowest average training loss are pickled to
    ``<wandb run dir>/parameters``.

    Args:
        args: Run configuration. Attributes read here: ``seed``, ``n_neighbors``,
            ``split``, ``hidden_nf``, ``n_layers``, ``n_steps``, ``warmup``, ``lr``,
            ``alpha1``-``alpha4``, ``n_dim`` and ``epochs``. It is also forwarded
            to ``data.ARCHDataLoader`` and ``loss_f.low_rank_loss``.
    """
    run = wandb.init(project="PROJECT", config=args)

    key = random.PRNGKey(args.seed)
    key, rs_key, dataloader_key, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)
    isomap = Isomap(n_neighbors=args.n_neighbors).fit(x)
    dij = isomap.dist_matrix_
    dij_max = dij.max()
    # Largest dij[i, k] - dij[j, k] over all (i, j, k). Taking, per column k,
    # (max over i) - (min over j) yields exactly the same value as materialising
    # the (N, N, N) difference tensor, but uses O(N^2) instead of O(N^3) memory.
    dij_diff_max = (dij.max(axis=0) - dij.min(axis=0)).max()

    if args.split:
        # NOTE(review): the train fraction is fixed at 0.8 here, whereas the other
        # trainers in this file read ``args.train_size`` -- confirm this is intended.
        n_train = int(x.shape[0] * 0.8)
        train_idx = random.choice(rs_key, x.shape[0], (n_train,), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(x.shape[0]), train_idx)
        train_dataloader = data.ARCHDataLoader(x[train_idx], dij[train_idx][:, train_idx], args, key=dataloader_key)
        test_dataloader = data.ARCHDataLoader(x[test_idx], dij[test_idx][:, test_idx], args, key=dataloader_key)
    else:
        train_dataloader = data.ARCHDataLoader(x, dij, args, key=dataloader_key)
        test_dataloader = train_dataloader
    # A single batch serves as the example input for parameter initialisation.
    train_sample, _ = next(iter(train_dataloader))

    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=args.hidden_nf, n_layers=args.n_layers, n_steps=args.n_steps, seed=args.seed)
    pars = varphi.init(init_key, train_sample)

    warmup = args.warmup
    optim = optax.adam(learning_rate=args.lr)
    opt_state = optim.init(pars)

    @jax.jit
    def compute_loss(pars, x, dij, key, epoch):
        """Return the training objective and the 7 individual loss terms.

        ``key`` is accepted for interface symmetry but not used. Only the global,
        graph-matching, low-dimensional and Jacobian terms enter the objective
        (the first two only once ``epoch >= warmup``); the kinetic, low-rank and
        inverse terms are computed for logging only.
        """
        z, vfs, e_vjps, jac_trace_est = varphi.apply(pars, x)
        z_inv, _ = varphi.apply(pars, z, method="inverse")
        # Pairwise Euclidean distances in the embedding; the epsilon keeps the
        # sqrt differentiable on the zero diagonal.
        d_zij = jnp.sqrt(jnp.sum((z[None] - z[:, None]) ** 2, axis=-1) + 1e-8)

        global_loss = loss_f.global_loss(dij, d_zij, dij_max)
        gm_loss = loss_f.graph_matching_loss(dij, d_zij, dij_diff_max)
        ld_loss = loss_f.low_dimensional_loss(z, args.n_dim)
        kinetic_loss = loss_f.kinetic_energy_loss(vfs)
        jacobian_loss = loss_f.jacobian_loss(e_vjps)
        low_rank_loss = loss_f.low_rank_loss(x, z, varphi, pars, args)
        inverse_loss = loss_f.inverse_loss(x, z_inv)

        # ``epoch`` is traced, so the warm-up switch must be a lax.cond.
        loss = jax.lax.cond(epoch >= warmup,
                            lambda _: args.alpha1 * global_loss + args.alpha2 * gm_loss + args.alpha3 * ld_loss + args.alpha4 * jacobian_loss,
                            lambda _: args.alpha3 * ld_loss + args.alpha4 * jacobian_loss,
                            operand=None)

        return loss, (global_loss, gm_loss, ld_loss, kinetic_loss, jacobian_loss, low_rank_loss, inverse_loss)

    @jax.jit
    def update_params(pars, x, dij, opt_state, key, epoch):
        """Take one Adam step; return (loss, loss terms, new params, new opt state)."""
        (loss, losses), grad = jax.value_and_grad(compute_loss, has_aux=True)(pars, x, dij, key, epoch)
        updates, opt_state = optim.update(grad, opt_state)
        pars = optax.apply_updates(pars, updates)
        return loss, losses, pars, opt_state

    @jax.jit
    def model_eval(pars, x):
        """Return the embedding ``z`` of ``x`` and its reconstruction via the inverse flow."""
        z, _, _, _ = varphi.apply(pars, x)
        x_approx, _ = varphi.apply(pars, z, method="inverse")
        return z, x_approx

    best_loss = float("inf")
    best_pars = pars
    for i in range(args.epochs):
        train_loss = 0.0
        train_losses = jnp.zeros(7)
        for batch_x, batch_dij in train_dataloader:
            # Split before consuming so the step key is never split again later.
            key, step_key = random.split(key)
            loss_batch, losses_batch, pars, opt_state = update_params(pars, batch_x, batch_dij, opt_state, step_key, i)
            train_loss += loss_batch
            train_losses += jnp.array(losses_batch)
        train_loss = train_loss / len(train_dataloader)
        train_losses = train_losses / len(train_dataloader)
        run.log({"epoch": i, "loss": train_loss, "global loss": train_losses[0], "gm loss": train_losses[1], "ld loss": train_losses[2], "kinetic energy loss": train_losses[3], "jacobian loss": train_losses[4], "low rank loss": train_losses[5], "inverse loss": train_losses[6]})

        test_metrics = jnp.zeros(3)
        for batch_x, batch_dij in test_dataloader:
            batch_z, batch_x_approx = model_eval(pars, batch_x)
            metrics_batch = jnp.array([evaluation.invertibility(batch_x, batch_x_approx), evaluation.low_dimensionality(batch_z, args.n_dim), evaluation.isometry(batch_z, batch_dij, dij_max)])
            test_metrics += metrics_batch
        test_metrics = test_metrics / len(test_dataloader)
        run.log({"epoch": i, "invertibility": test_metrics[0], "low dimensionality": test_metrics[1], "isometry": test_metrics[2]})

        # Model selection is based on the average *training* loss.
        if train_loss < best_loss:
            best_loss = train_loss
            best_pars = pars
    with open(os.path.join(wandb.run.dir, "parameters"), 'wb') as fp:
        pickle.dump(best_pars, fp)
    run.finish()

def train_VAE(args):
    """Train a VAE baseline and keep the parameters with the lowest test loss.

    Per-epoch train and test losses (total, reconstruction, KL) are logged to
    Weights & Biases; the selected parameters are pickled to
    ``<wandb run dir>/parameters``.

    Args:
        args: Run configuration. Attributes read here: ``seed``, ``n_neighbors``,
            ``split``, ``train_size``, ``n_layers``, ``hidden_nf``, ``n_dim``,
            ``lr``, ``beta`` and ``epochs``. It is also forwarded to
            ``data.ARCHDataLoader``.
    """
    run = wandb.init(project="PROJECT", config=args)

    key = random.PRNGKey(args.seed)
    key, rs_key, dataloader_key, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)
    # ARCHDataLoader takes a distance matrix, although the VAE objective below
    # does not use it.
    isomap = Isomap(n_neighbors=args.n_neighbors).fit(x)
    dij = isomap.dist_matrix_

    if args.split:
        train_idx = random.choice(rs_key, x.shape[0], (int(x.shape[0] * args.train_size),), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(x.shape[0]), train_idx)
        train_dataloader = data.ARCHDataLoader(x[train_idx], dij[train_idx][:, train_idx], args, key=dataloader_key)
        test_dataloader = data.ARCHDataLoader(x[test_idx], dij[test_idx][:, test_idx], args, key=dataloader_key)
    else:
        train_dataloader = data.ARCHDataLoader(x, dij, args, key=dataloader_key)
        test_dataloader = train_dataloader
    train_sample, _ = next(iter(train_dataloader))

    vae = model.VAE(x_dim=x.shape[1], n_encoder_layers=args.n_layers, n_decoder_layers=args.n_layers, hidden_nf=args.hidden_nf, latent_nf=args.n_dim)
    # Dedicated keys for parameter init and for the model's sampling argument,
    # so neither overlaps with the training-loop key stream derived from ``key``.
    params_key, sample_key = random.split(init_key)
    pars = vae.init(params_key, train_sample, sample_key)

    optim = optax.adam(learning_rate=args.lr)
    opt_state = optim.init(pars)

    @jax.jit
    def compute_loss(pars, x, key):
        """Return the batch-mean of reconstruction + beta * KL, and both terms separately."""
        mu_x, log_var_x, mu_z, log_var_z = vae.apply(pars, x, key)
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
        kl_loss = - jnp.sum(0.5 * (1 + log_var_z - jnp.exp(log_var_z) - mu_z ** 2), axis=-1)
        # NOTE(review): this mirrors the KL formula with (x - mu_x) in place of
        # mu_z rather than the Gaussian negative log-likelihood
        # 0.5 * (log_var_x + (x - mu_x)**2 / exp(log_var_x)); confirm it is intended.
        reconstruction_loss = - jnp.sum(0.5 * (1 + log_var_x - jnp.exp(log_var_x) - (x - mu_x) ** 2), axis=-1)
        return jnp.mean(reconstruction_loss + args.beta * kl_loss), (jnp.mean(reconstruction_loss), jnp.mean(kl_loss))

    @jax.jit
    def update_params(pars, x, opt_state, key):
        """Take one Adam step; return (loss, loss terms, new params, new opt state)."""
        (loss, losses), grad = jax.value_and_grad(compute_loss, has_aux=True)(pars, x, key)
        updates, opt_state = optim.update(grad, opt_state)
        pars = optax.apply_updates(pars, updates)
        return loss, losses, pars, opt_state

    best_loss = float("inf")
    best_pars = pars
    for i in range(args.epochs):
        train_loss = 0.0
        train_losses = jnp.zeros(2)
        for batch_x, _ in train_dataloader:
            # Split before consuming so each step gets a fresh, never-reused key.
            key, step_key = random.split(key)
            loss_batch, losses_batch, pars, opt_state = update_params(pars, batch_x, opt_state, step_key)
            train_loss += loss_batch
            train_losses += jnp.array(losses_batch)
        train_loss = train_loss / len(train_dataloader)
        train_losses = train_losses / len(train_dataloader)
        run.log({"epoch": i, "loss": train_loss, "reconstruction loss": train_losses[0], "kl loss": train_losses[1]})

        test_loss = 0.0
        test_losses = jnp.zeros(2)
        for batch_x, _ in test_dataloader:
            key, eval_key = random.split(key)
            loss_batch, losses_batch = compute_loss(pars, batch_x, eval_key)
            test_loss += loss_batch
            test_losses += jnp.array(losses_batch)
        # Average over the number of *test* batches and log the test values
        # (previously the train dataloader length and train losses were used).
        test_loss = test_loss / len(test_dataloader)
        test_losses = test_losses / len(test_dataloader)
        run.log({"epoch": i, "test loss": test_loss, "test reconstruction loss": test_losses[0], "test kl loss": test_losses[1]})
        if test_loss < best_loss:
            best_loss = test_loss
            best_pars = pars
    with open(os.path.join(wandb.run.dir, "parameters"), 'wb') as fp:
        pickle.dump(best_pars, fp)
    run.finish()




def train_CFM(args):
    """Train a (conditional) flow-matching vector field on standardised data.

    Each feature of ``x`` is standardised, and a ``model.VectorField`` is regressed
    onto the velocity ``x1 - x0`` of the straight path from Gaussian noise ``x0``
    to a data batch ``x1``. Per-epoch train/test losses are logged to Weights &
    Biases and the parameters with the lowest test loss are pickled to
    ``<wandb run dir>/parameters``.

    Args:
        args: Run configuration. Attributes read here: ``seed``, ``split``,
            ``train_size``, ``hidden_nf``, ``n_layers``, ``lr``, ``batch_size``
            and ``epochs``. It is also forwarded to ``data.ARCHGenDataLoader``.
    """
    run = wandb.init(project="PROJECT", config=args)

    key = random.PRNGKey(args.seed)
    key, rs_key, dataloader_key, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)
    # Per-feature standardisation. NOTE(review): statistics use the full dataset
    # (train + test), and a constant feature would divide by zero.
    x_normalized = (x - x.mean(axis=0)) / x.std(axis=0)
    n_samples = x_normalized.shape[0]

    if args.split:
        train_idx = random.choice(rs_key, n_samples, (int(n_samples * args.train_size),), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(n_samples), train_idx)
        train_dataloader = data.ARCHGenDataLoader(x_normalized[train_idx], args, key=dataloader_key)
        test_dataloader = data.ARCHGenDataLoader(x_normalized[test_idx], args, key=dataloader_key)
        n_train = train_idx.shape[0]
    else:
        train_dataloader = data.ARCHGenDataLoader(x_normalized, args, key=dataloader_key)
        test_dataloader = train_dataloader
        # Previously ``train_idx`` was referenced below even when it was never
        # defined, raising NameError for unsplit runs.
        n_train = n_samples
    train_sample = next(iter(train_dataloader))

    vector_field = model.VectorField(x_dim=x.shape[1], hidden_nf=args.hidden_nf, n_layers=args.n_layers)
    pars = vector_field.init(init_key, train_sample, 0.0)

    # Cosine decay over (n_train // batch_size) steps per epoch. optax's ``alpha``
    # is the final learning rate as a *fraction* of ``args.lr``: 0.01 ends the
    # schedule at lr / 100 (``alpha=args.lr/100`` would end at lr**2 / 100).
    # optax requires decay_steps > 0, hence the floor of 1.
    decay_steps = max(1, (n_train // args.batch_size) * args.epochs)
    lr_scheduler = optax.cosine_decay_schedule(args.lr, decay_steps, alpha=0.01)
    optim = optax.adam(learning_rate=lr_scheduler)
    opt_state = optim.init(pars)

    @jax.jit
    def compute_loss(pars, x, key):
        """Flow-matching loss for one batch, with a single time ``t`` shared by the batch."""
        uniform_key, normal_key = random.split(key)
        t = random.uniform(uniform_key)
        x0 = random.normal(normal_key, x.shape)
        x1 = x
        # Straight-line interpolation between noise and data; its velocity is x1 - x0.
        xt = (1 - t) * x0 + t * x1
        ut = x1 - x0
        vt = vector_field.apply(pars, xt, t)
        loss = jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))
        return loss

    @jax.jit
    def update_params(pars, x, opt_state, key):
        """Take one Adam step; return (loss, new params, new opt state)."""
        loss, grad = jax.value_and_grad(compute_loss)(pars, x, key)
        updates, opt_state = optim.update(grad, opt_state)
        pars = optax.apply_updates(pars, updates)
        return loss, pars, opt_state

    best_loss = float("inf")
    # Initialise so the final pickle works even if no epoch improves
    # (e.g. ``args.epochs == 0`` or a NaN loss).
    best_pars = pars
    for i in range(args.epochs):
        epoch_train_loss = 0.0
        epoch_test_loss = 0.0
        for batch_x in train_dataloader:
            key, random_key = random.split(key)
            loss, pars, opt_state = update_params(pars, batch_x, opt_state, random_key)
            epoch_train_loss += loss
        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        run.log({"epoch": i, "loss": epoch_train_loss})
        for batch_x in test_dataloader:
            key, random_key = random.split(key)
            loss = compute_loss(pars, batch_x, random_key)
            epoch_test_loss += loss
        epoch_test_loss = epoch_test_loss / len(test_dataloader)
        run.log({"epoch": i, "test loss": epoch_test_loss})
        if epoch_test_loss < best_loss:
            best_loss = epoch_test_loss
            best_pars = pars
    with open(os.path.join(wandb.run.dir, "parameters"), 'wb') as fp:
        pickle.dump(best_pars, fp)
    run.finish()


def train_PFM(args):
    """Train a flow-matching vector field on the embedding of a pretrained CNF.

    The CNF parameters pickled at ``args.path`` map the data ``x`` to ``z``. Each
    coordinate of ``z`` is standardised, and a ``model.VectorField`` is regressed
    onto the velocity ``z1 - z0`` of the straight path from Gaussian noise
    ``z0`` to a batch ``z1``. Per-epoch train/test losses are logged to Weights &
    Biases and the parameters with the lowest test loss are pickled to
    ``<wandb run dir>/parameters``.

    Args:
        args: Run configuration. Attributes read here: ``seed``, ``path``,
            ``varphi_hidden_nf``, ``varphi_n_layers``, ``varphi_n_steps``,
            ``split``, ``train_size``, ``hidden_nf``, ``n_layers``, ``lr``,
            ``batch_size`` and ``epochs``. It is also forwarded to
            ``data.ARCHGenDataLoader``.
    """
    run = wandb.init(project="PROJECT", config=args)

    key = random.PRNGKey(args.seed)
    key, rs_key, dataloader_key, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)

    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(args.path, "rb") as fp:
        varphi_pars = pickle.load(fp)
    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=args.varphi_hidden_nf, n_layers=args.varphi_n_layers, n_steps=args.varphi_n_steps, seed=args.seed)
    z, _, _, _ = varphi.apply(varphi_pars, x)
    # Per-coordinate standardisation. NOTE(review): statistics use the full
    # dataset (train + test), and a constant coordinate would divide by zero.
    z_normalized = (z - z.mean(axis=0)) / z.std(axis=0)
    n_samples = z_normalized.shape[0]

    if args.split:
        train_idx = random.choice(rs_key, n_samples, (int(n_samples * args.train_size),), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(n_samples), train_idx)
        train_dataloader = data.ARCHGenDataLoader(z_normalized[train_idx], args, key=dataloader_key)
        test_dataloader = data.ARCHGenDataLoader(z_normalized[test_idx], args, key=dataloader_key)
        n_train = train_idx.shape[0]
    else:
        train_dataloader = data.ARCHGenDataLoader(z_normalized, args, key=dataloader_key)
        test_dataloader = train_dataloader
        # Previously ``train_idx`` was referenced below even when it was never
        # defined, raising NameError for unsplit runs.
        n_train = n_samples
    train_sample = next(iter(train_dataloader))

    vector_field = model.VectorField(x_dim=z_normalized.shape[1], hidden_nf=args.hidden_nf, n_layers=args.n_layers)
    pars = vector_field.init(init_key, train_sample, 0.0)

    # Cosine decay over (n_train // batch_size) steps per epoch. optax's ``alpha``
    # is the final learning rate as a *fraction* of ``args.lr``: 0.01 ends the
    # schedule at lr / 100 (``alpha=args.lr/100`` would end at lr**2 / 100).
    # optax requires decay_steps > 0, hence the floor of 1.
    decay_steps = max(1, (n_train // args.batch_size) * args.epochs)
    lr_scheduler = optax.cosine_decay_schedule(args.lr, decay_steps, alpha=0.01)
    optim = optax.adam(learning_rate=lr_scheduler)
    opt_state = optim.init(pars)

    @jax.jit
    def compute_loss(pars, z, key):
        """Flow-matching loss for one batch, with a single time ``t`` shared by the batch."""
        uniform_key, normal_key = random.split(key)
        t = random.uniform(uniform_key)
        z0 = random.normal(normal_key, z.shape)
        z1 = z
        # Straight-line interpolation between noise and data; its velocity is z1 - z0.
        zt = (1 - t) * z0 + t * z1
        ut = z1 - z0
        vt = vector_field.apply(pars, zt, t)
        loss = jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))
        return loss

    @jax.jit
    def update_params(pars, z, opt_state, key):
        """Take one Adam step; return (loss, new params, new opt state)."""
        loss, grad = jax.value_and_grad(compute_loss)(pars, z, key)
        updates, opt_state = optim.update(grad, opt_state)
        pars = optax.apply_updates(pars, updates)
        return loss, pars, opt_state

    best_loss = float("inf")
    # Initialise so the final pickle works even if no epoch improves
    # (e.g. ``args.epochs == 0`` or a NaN loss).
    best_pars = pars
    for i in range(args.epochs):
        epoch_train_loss = 0.0
        epoch_test_loss = 0.0
        for batch_z in train_dataloader:
            key, random_key = random.split(key)
            loss, pars, opt_state = update_params(pars, batch_z, opt_state, random_key)
            epoch_train_loss += loss
        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        run.log({"epoch": i, "loss": epoch_train_loss})
        for batch_z in test_dataloader:
            key, random_key = random.split(key)
            loss = compute_loss(pars, batch_z, random_key)
            epoch_test_loss += loss
        epoch_test_loss = epoch_test_loss / len(test_dataloader)
        run.log({"epoch": i, "test loss": epoch_test_loss})
        if epoch_test_loss < best_loss:
            best_loss = epoch_test_loss
            best_pars = pars
    with open(os.path.join(wandb.run.dir, "parameters"), 'wb') as fp:
        pickle.dump(best_pars, fp)
    run.finish()

def train_d_PFM(args):
    """Train a flow-matching vector field on the last ``args.n_dim`` embedding coordinates.

    Like ``train_PFM``, except that the vector field only sees
    ``z_normalized[:, -args.n_dim:]``. The CNF parameters pickled at ``args.path``
    map the data ``x`` to ``z``, and each coordinate of ``z`` is standardised
    before slicing. Per-epoch train/test losses are logged to Weights & Biases
    and the parameters with the lowest test loss are pickled to
    ``<wandb run dir>/parameters``.

    Args:
        args: Run configuration. Attributes read here: ``seed``, ``path``,
            ``varphi_hidden_nf``, ``varphi_n_layers``, ``varphi_n_steps``,
            ``split``, ``train_size``, ``n_dim``, ``hidden_nf``, ``n_layers``,
            ``lr``, ``batch_size`` and ``epochs``. It is also forwarded to
            ``data.ARCHGenDataLoader``.
    """
    run = wandb.init(project="PROJECT", config=args)

    key = random.PRNGKey(args.seed)
    key, rs_key, dataloader_key, init_key = random.split(key, 4)

    x, _ = data.load_data(debug=True)

    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(args.path, "rb") as fp:
        varphi_pars = pickle.load(fp)
    varphi = model.CNF(x_dim=x.shape[1], hidden_nf=args.varphi_hidden_nf, n_layers=args.varphi_n_layers, n_steps=args.varphi_n_steps, seed=args.seed)
    z, _, _, _ = varphi.apply(varphi_pars, x)
    # Per-coordinate standardisation. NOTE(review): statistics use the full
    # dataset (train + test), and a constant coordinate would divide by zero.
    z_normalized = (z - z.mean(axis=0)) / z.std(axis=0)
    n_samples = z_normalized.shape[0]

    if args.split:
        train_idx = random.choice(rs_key, n_samples, (int(n_samples * args.train_size),), replace=False)
        test_idx = jnp.setdiff1d(jnp.arange(n_samples), train_idx)
        train_dataloader = data.ARCHGenDataLoader(z_normalized[train_idx], args, key=dataloader_key)
        test_dataloader = data.ARCHGenDataLoader(z_normalized[test_idx], args, key=dataloader_key)
        n_train = train_idx.shape[0]
    else:
        train_dataloader = data.ARCHGenDataLoader(z_normalized, args, key=dataloader_key)
        test_dataloader = train_dataloader
        # Previously ``train_idx`` was referenced below even when it was never
        # defined, raising NameError for unsplit runs.
        n_train = n_samples
    train_sample = next(iter(train_dataloader))

    # The vector field operates only on the last ``n_dim`` coordinates.
    vector_field = model.VectorField(x_dim=args.n_dim, hidden_nf=args.hidden_nf, n_layers=args.n_layers)
    pars = vector_field.init(init_key, train_sample[:, -args.n_dim:], 0.0)

    # Cosine decay over (n_train // batch_size) steps per epoch. optax's ``alpha``
    # is the final learning rate as a *fraction* of ``args.lr``: 0.01 ends the
    # schedule at lr / 100 (``alpha=args.lr/100`` would end at lr**2 / 100).
    # optax requires decay_steps > 0, hence the floor of 1.
    decay_steps = max(1, (n_train // args.batch_size) * args.epochs)
    lr_scheduler = optax.cosine_decay_schedule(args.lr, decay_steps, alpha=0.01)
    optim = optax.adam(learning_rate=lr_scheduler)
    opt_state = optim.init(pars)

    @jax.jit
    def compute_loss(pars, z, key):
        """Flow-matching loss for one batch, with a single time ``t`` shared by the batch."""
        uniform_key, normal_key = random.split(key)
        t = random.uniform(uniform_key)
        z0 = random.normal(normal_key, z.shape)
        z1 = z
        # Straight-line interpolation between noise and data; its velocity is z1 - z0.
        zt = (1 - t) * z0 + t * z1
        ut = z1 - z0
        vt = vector_field.apply(pars, zt, t)
        loss = jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))
        return loss

    @jax.jit
    def update_params(pars, z, opt_state, key):
        """Take one Adam step; return (loss, new params, new opt state)."""
        loss, grad = jax.value_and_grad(compute_loss)(pars, z, key)
        updates, opt_state = optim.update(grad, opt_state)
        pars = optax.apply_updates(pars, updates)
        return loss, pars, opt_state

    best_loss = float("inf")
    # Initialise so the final pickle works even if no epoch improves
    # (e.g. ``args.epochs == 0`` or a NaN loss).
    best_pars = pars
    for i in range(args.epochs):
        epoch_train_loss = 0.0
        epoch_test_loss = 0.0
        for batch_z in train_dataloader:
            key, random_key = random.split(key)
            loss, pars, opt_state = update_params(pars, batch_z[:, -args.n_dim:], opt_state, random_key)
            epoch_train_loss += loss
        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        run.log({"epoch": i, "loss": epoch_train_loss})
        for batch_z in test_dataloader:
            key, random_key = random.split(key)
            loss = compute_loss(pars, batch_z[:, -args.n_dim:], random_key)
            epoch_test_loss += loss
        epoch_test_loss = epoch_test_loss / len(test_dataloader)
        run.log({"epoch": i, "test loss": epoch_test_loss})
        if epoch_test_loss < best_loss:
            best_loss = epoch_test_loss
            best_pars = pars
    with open(os.path.join(wandb.run.dir, "parameters"), 'wb') as fp:
        pickle.dump(best_pars, fp)
    run.finish()